# NSW CATE Analysis

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
import warnings
import sys
from datetime import datetime
warnings.filterwarnings('ignore')

class TeeOutput:
    """File-like object that writes output to both the console and a log file.

    The console stream is whatever ``sys.stdout`` is at construction time;
    assign the instance to ``sys.stdout`` to mirror every ``print`` into the log.
    """

    def __init__(self, filename):
        """Capture the current ``sys.stdout`` and open ``filename`` for writing.

        The file is truncated. It is opened as UTF-8 because this module prints
        non-ASCII symbols (ε, γ, ρ, √, →), which cannot be encoded with some
        platform default encodings (e.g. cp1252 on Windows).
        """
        self.terminal = sys.stdout
        self.log = open(filename, 'w', encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the console and the log, flushing the log immediately."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush eagerly so the log is complete even if the run is interrupted.
        self.log.flush()

    def flush(self):
        """Flush both the console stream and the log file."""
        self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; the console stream is left open."""
        self.log.close()

class NSWCATEAllocator:
    """NSW CATE allocation algorithm with fixed gamma and a configurable heavy-interval threshold.

    Workflow:

    1. Individuals are partitioned into groups by one of several grouping
       strategies.
    2. Each group's CATE is estimated as mean(treated outcome) minus
       mean(control outcome). Only groups with adequate treatment balance
       are kept.
    3. The CATEs are min-max normalized to [0, 1] and used as true Bernoulli
       success probabilities τ.
    4. A simulated top-K allocation experiment then estimates τ from samples
       at accuracy ρ = γ√ε (and at accuracy ε for comparison) and records
       how often the estimated top-K reaches a (1 - ε) fraction of the
       optimal value.
    """

    # Covariates used by the model-based grouping methods (used when present in df).
    _MODEL_FEATURES = ['age', 'education', 'black', 'hispanic', 'married', 'nodegree', 're75']

    def __init__(self, epsilon=0.1, gamma=0.5, delta=0.05, heavy_multiplier=1.6, random_seed=42):
        """Store the algorithm parameters, seed NumPy's global RNG and print a banner.

        Args:
            epsilon: Approximation parameter ε.
            gamma: Multiplier γ in ρ = γ√ε.
            delta: Failure probability δ used in the sample-size bound of `estimate_tau`.
            heavy_multiplier: An interval is flagged "heavy" when it holds more than
                `heavy_multiplier` times its expected number of units.
            random_seed: Seed for `np.random.seed` and for the scikit-learn models.
        """
        self.epsilon = epsilon
        self.gamma = gamma
        self.rho = gamma * np.sqrt(epsilon)
        self.delta = delta
        self.heavy_multiplier = heavy_multiplier
        self.random_seed = random_seed
        # Seeds NumPy's *global* RNG; run_single_trial re-seeds it per trial.
        np.random.seed(random_seed)

        print("NSW CATE Allocation Algorithm")
        print(f"ε = {epsilon}")
        print(f"√ε = {np.sqrt(epsilon):.6f}")
        print(f"γ = {gamma}")
        print(f"ρ = γ√ε = {self.rho:.6f}")
        print(f"Heavy multiplier = {heavy_multiplier}x")
        print(f"δ = {delta}")
        print("="*60)

    def process_nsw_data(self, df, outcome_col='re78', treatment_col='treat'):
        """Return a cleaned copy of `df` with standardized analysis columns.

        The copy gets three new columns:

        - `treatment`, copied from `treatment_col`.
        - `outcome`, copied from `outcome_col`.
        - `baseline_earnings`, copied from `re75`, or 0 when that column is absent.

        Rows with a missing outcome or treatment are dropped. `df` itself is
        not modified.

        Raises:
            ValueError: If `treatment_col` or `outcome_col` is not a column of `df`.
        """
        print(f"Processing NSW data with {len(df)} observations")
        print(f"Available columns: {list(df.columns)}")

        df_processed = df.copy()

        if treatment_col not in df_processed.columns:
            raise ValueError(f"Missing required treatment column: {treatment_col}")
        if outcome_col not in df_processed.columns:
            raise ValueError(f"Missing required outcome column: {outcome_col}")

        df_processed['treatment'] = df_processed[treatment_col]
        df_processed['outcome'] = df_processed[outcome_col]

        if 're75' in df_processed.columns:
            df_processed['baseline_earnings'] = df_processed['re75']
        else:
            df_processed['baseline_earnings'] = 0  # Default if no baseline

        initial_size = len(df_processed)
        df_processed = df_processed.dropna(subset=['outcome', 'treatment'])
        final_size = len(df_processed)

        if initial_size != final_size:
            print(f"Dropped {initial_size - final_size} rows due to missing outcome/treatment")

        print(f"Final dataset: {final_size} individuals")
        print(f"Treatment distribution: {df_processed['treatment'].value_counts().to_dict()}")
        print(f"Outcome (1978 earnings) statistics: mean=${df_processed['outcome'].mean():.0f}, std=${df_processed['outcome'].std():.0f}")
        # baseline_earnings is always assigned above, so its summary is always available.
        print(f"Baseline (1975 earnings) stats: mean=${df_processed['baseline_earnings'].mean():.0f}, std=${df_processed['baseline_earnings'].std():.0f}")

        return df_processed

    @staticmethod
    def _bin_by_edges(values, edges, n_bins):
        """Map each value to a bin index in 0..n_bins-1 given n_bins+1 sorted quantile edges.

        `np.digitize(values, edges) - 1` places a value equal to the last edge
        into an extra bin `n_bins`. Callers only iterate bins 0..n_bins-1, so
        that bin would silently drop the largest observation(s). Clipping folds
        them into the top bin.
        """
        bins = np.digitize(values, edges) - 1
        return np.clip(bins, 0, n_bins - 1)

    @staticmethod
    def _impute_covariates(df, feature_cols):
        """Return a copy of `df[feature_cols]` with missing values imputed.

        Columns with int64/float64 dtype are filled with their median. Any other
        column is filled with its mode, or with 0 when the column has no mode.
        """
        X = df[feature_cols].copy()
        for col in X.columns:
            if X[col].isna().any():
                if X[col].dtype in ['int64', 'float64']:
                    X[col] = X[col].fillna(X[col].median())
                else:
                    modes = X[col].mode()
                    X[col] = X[col].fillna(modes[0] if len(modes) > 0 else 0)
        return X

    def create_demographics_groups(self, df, min_size=6):
        """Group individuals by combinations of binary demographic indicators.

        Only the first three available columns among black, hispanic, married
        and nodegree are used. Combinations with fewer than `min_size` members
        are discarded before the balance filter is applied.

        Returns:
            List of balanced group dicts (see `_ensure_balance_and_compute_cate`).
        """
        print(f"Creating NSW demographics groups")

        demo_features = ['black', 'hispanic', 'married', 'nodegree']
        available_features = [col for col in demo_features if col in df.columns]

        if not available_features:
            print("No demographic variables found")
            return []

        # Limit to 3 features to avoid too many combinations; report the
        # features actually used (previously printed before truncation).
        available_features = available_features[:3]
        print(f"Using demographic features: {available_features}")

        df_clean = df.dropna(subset=available_features)
        print(f"After removing missing values: {len(df_clean)}/{len(df)} individuals")

        if len(df_clean) == 0:
            return []

        groups = []
        unique_combinations = df_clean[available_features].drop_duplicates()
        print(f"Found {len(unique_combinations)} unique demographic combinations")

        for _, combo in unique_combinations.iterrows():
            mask = pd.Series(True, index=df.index)
            combo_description = []

            for feature in available_features:
                mask = mask & (df[feature] == combo[feature])
                combo_description.append(f"{feature}={combo[feature]}")

            indices = df[mask].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': "_".join(combo_description),
                    'indices': indices,
                    'type': 'demographics'
                })

        print(f"Created {len(groups)} demographic groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_age_groups(self, df, n_groups=30, min_size=6):
        """Group individuals into `n_groups` age quantile brackets.

        Missing ages are imputed with the median age. Tied quantile edges can
        leave some brackets empty.
        """
        print(f"Creating age groups (target: {n_groups})")

        if 'age' not in df.columns:
            print("No age variable found")
            return []

        age = df['age'].fillna(df['age'].median())
        cuts = np.percentile(age, np.linspace(0, 100, n_groups + 1))
        bins = self._bin_by_edges(age, cuts, n_groups)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'age_group_{i}',
                    'indices': indices,
                    'type': 'age'
                })

        print(f"Created {len(groups)} age groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_education_groups(self, df, min_size=6):
        """Create one group per distinct non-missing value of the `education` column."""
        print(f"Creating education groups")

        if 'education' not in df.columns:
            print("No education variable found")
            return []

        groups = []
        for education_level in df['education'].unique():
            if pd.isna(education_level):
                continue

            indices = df[df['education'] == education_level].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'education_{education_level}_years',
                    'indices': indices,
                    'type': 'education'
                })

        print(f"Created {len(groups)} education groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_baseline_earnings_groups(self, df, n_groups=30, min_size=6):
        """Group individuals by `baseline_earnings` (missing values treated as 0).

        The strategy depends on the share of zero earners:

        - More than 30% zero earners: the zero earners form their own group,
          and positive earners are split into `n_groups - 1` quantile brackets.
        - Otherwise: everyone is split into `n_groups` quantile brackets.

        Returns [] when the column is absent or sums to 0.
        """
        print(f"Creating baseline earnings groups (target: {n_groups})")

        if 'baseline_earnings' not in df.columns or df['baseline_earnings'].sum() == 0:
            print("No baseline earnings data available")
            return []

        earnings = df['baseline_earnings'].fillna(0)
        groups = []

        if (earnings == 0).mean() > 0.3:
            zero_earners = df.index[earnings == 0].tolist()
            if len(zero_earners) >= min_size:
                groups.append({
                    'id': 'zero_earnings_1975',
                    'indices': zero_earners,
                    'type': 'baseline_earnings'
                })

            positive_earnings = earnings[earnings > 0]
            if len(positive_earnings) > 0:
                cuts = np.percentile(positive_earnings, np.linspace(0, 100, n_groups))

                for i in range(len(cuts) - 1):
                    # Brackets are (cuts[i], cuts[i+1]]; the first one is closed
                    # on the left so the smallest positive earner is included.
                    lower = (earnings >= cuts[i]) if i == 0 else (earnings > cuts[i])
                    mask = lower & (earnings <= cuts[i + 1])
                    indices = df.index[mask].tolist()
                    if len(indices) >= min_size:
                        groups.append({
                            'id': f'earnings_1975_bracket_{i}',
                            'indices': indices,
                            'type': 'baseline_earnings'
                        })
        else:
            cuts = np.percentile(earnings, np.linspace(0, 100, n_groups + 1))
            bins = self._bin_by_edges(earnings, cuts, n_groups)

            for i in range(n_groups):
                indices = df.index[bins == i].tolist()
                if len(indices) >= min_size:
                    groups.append({
                        'id': f'earnings_1975_bracket_{i}',
                        'indices': indices,
                        'type': 'baseline_earnings'
                    })

        print(f"Created {len(groups)} baseline earnings groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_race_ethnicity_groups(self, df, min_size=6):
        """Group individuals as black / hispanic / white_other from the indicator columns.

        `black == 1` takes precedence over `hispanic == 1`. Note: this adds a
        `race_ethnicity` column to the caller's `df` in place.
        """
        print(f"Creating race/ethnicity groups")

        def get_race_ethnicity(row):
            if row.get('black', 0) == 1:
                return 'black'
            elif row.get('hispanic', 0) == 1:
                return 'hispanic'
            else:
                return 'white_other'

        if 'black' not in df.columns and 'hispanic' not in df.columns:
            print("No race/ethnicity variables found")
            return []

        df['race_ethnicity'] = df.apply(get_race_ethnicity, axis=1)

        groups = []
        for race_eth in df['race_ethnicity'].unique():
            indices = df[df['race_ethnicity'] == race_eth].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'race_{race_eth}',
                    'indices': indices,
                    'type': 'race_ethnicity'
                })

        print(f"Created {len(groups)} race/ethnicity groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_employment_status_groups(self, df, min_size=6):
        """Group individuals by `baseline_earnings` level.

        The levels are:

        - unemployed: 0 or missing.
        - low: below 5000.
        - higher: 5000 or more.

        Note: this adds an `employment_status` column to the caller's `df` in place.
        """
        print(f"Creating employment status groups")

        if 'baseline_earnings' not in df.columns:
            print("No baseline earnings data for employment status")
            return []

        def get_employment_status(earnings):
            if pd.isna(earnings) or earnings == 0:
                return 'unemployed_1975'
            elif earnings < 5000:
                return 'low_earnings_1975'
            else:
                return 'higher_earnings_1975'

        df['employment_status'] = df['baseline_earnings'].apply(get_employment_status)

        groups = []
        for status in df['employment_status'].unique():
            indices = df[df['employment_status'] == status].index.tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'employment_{status}',
                    'indices': indices,
                    'type': 'employment_status'
                })

        print(f"Created {len(groups)} employment status groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_causal_forest_groups(self, df, n_groups=30, min_size=6):
        """Cluster individuals on covariates plus a predicted CATE.

        Procedure:

        1. Fit separate random forests on treated and control rows.
        2. Predict the CATE for every row as the difference of the two forests'
           predictions.
        3. Run KMeans on the standardized covariates together with that
           predicted CATE.

        At most `n_groups` clusters are formed; the count is capped by the
        number of rows because KMeans requires that.
        """
        print(f"Creating causal forest groups (target: {n_groups})")

        available_features = [col for col in self._MODEL_FEATURES if col in df.columns]
        if not available_features:
            print("No features available for causal forest")
            return []

        X = self._impute_covariates(df, available_features)

        treated_mask = df['treatment'] == 1
        control_mask = df['treatment'] == 0

        if treated_mask.sum() < 5 or control_mask.sum() < 5:
            print("Not enough treated or control observations for causal forest")
            return []

        rf_treated = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)
        rf_control = RandomForestRegressor(n_estimators=100, random_state=self.random_seed)

        rf_treated.fit(X[treated_mask], df.loc[treated_mask, 'outcome'])
        rf_control.fit(X[control_mask], df.loc[control_mask, 'outcome'])

        pred_cate = rf_treated.predict(X) - rf_control.predict(X)
        cluster_features = np.column_stack([X.values, pred_cate.reshape(-1, 1)])
        cluster_features = StandardScaler().fit_transform(cluster_features)

        # KMeans raises if n_clusters exceeds the number of samples.
        n_clusters = min(n_groups, len(cluster_features))
        labels = KMeans(n_clusters=n_clusters, random_state=self.random_seed).fit_predict(cluster_features)

        groups = []
        for i in range(n_clusters):
            indices = df.index[labels == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'causal_forest_{i}',
                    'indices': indices,
                    'type': 'causal_forest'
                })

        print(f"Created {len(groups)} causal forest groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def create_propensity_groups(self, df, n_groups=50, min_size=6):
        """Stratify individuals into `n_groups` quantile bins of the propensity score.

        Scores are 5-fold cross-validated predictions of a logistic regression
        of `treatment` on the available covariates. Returns [] if scoring fails.
        """
        print(f"Creating propensity score groups (target: {n_groups})")

        available_features = [col for col in self._MODEL_FEATURES if col in df.columns]
        if not available_features:
            print("No features available for propensity scoring")
            return []

        X = self._impute_covariates(df, available_features)

        try:
            prop_scores = cross_val_predict(
                LogisticRegression(random_state=self.random_seed, max_iter=1000),
                X, df['treatment'], method='predict_proba', cv=5
            )[:, 1]
        except Exception as e:
            # Best-effort: report and skip this grouping method.
            print(f"Error computing propensity scores: {e}")
            return []

        quantiles = np.linspace(0, 1, n_groups + 1)
        bins = self._bin_by_edges(prop_scores, np.quantile(prop_scores, quantiles), n_groups)

        groups = []
        for i in range(n_groups):
            indices = df.index[bins == i].tolist()
            if len(indices) >= min_size:
                groups.append({
                    'id': f'propensity_{i}',
                    'indices': indices,
                    'type': 'propensity'
                })

        print(f"Created {len(groups)} propensity groups")
        return self._ensure_balance_and_compute_cate(df, groups)

    def _ensure_balance_and_compute_cate(self, df, groups):
        """Keep adequately balanced groups and attach their difference-in-means CATE.

        A group is kept only if its treatment rate is in [0.15, 0.85] and it
        has at least 3 treated and at least 3 control members.

        Returns:
            List of dicts with keys id, indices, size, treatment_rate,
            n_treated, n_control, cate and type.
        """
        balanced_groups = []

        for group in groups:
            group_df = df.loc[group['indices']]

            treatment_rate = group_df['treatment'].mean()
            n_treated = group_df['treatment'].sum()
            n_control = len(group_df) - n_treated

            if not (0.15 <= treatment_rate <= 0.85 and n_treated >= 3 and n_control >= 3):
                continue

            treated_outcomes = group_df[group_df['treatment'] == 1]['outcome']
            control_outcomes = group_df[group_df['treatment'] == 0]['outcome']
            cate = treated_outcomes.mean() - control_outcomes.mean()

            balanced_groups.append({
                'id': group['id'],
                'indices': group['indices'],
                'size': len(group_df),
                'treatment_rate': treatment_rate,
                'n_treated': int(n_treated),
                'n_control': int(n_control),
                'cate': cate,
                'type': group['type']
            })

        return balanced_groups

    def normalize_cates(self, groups):
        """Min-max normalize each group's `cate` into `normalized_cate` in [0, 1].

        If all CATEs are equal, every group gets 0.5. An empty list is returned
        unchanged. Groups are modified in place and also returned.
        """
        if not groups:
            return groups

        cates = [g['cate'] for g in groups]
        min_cate, max_cate = min(cates), max(cates)

        if max_cate > min_cate:
            span = max_cate - min_cate
            for group in groups:
                group['normalized_cate'] = (group['cate'] - min_cate) / span
        else:
            for group in groups:
                group['normalized_cate'] = 0.5

        print(f"CATE normalization: [{min_cate:.0f}, {max_cate:.0f}] → [0, 1]")
        return groups

    def plot_cate_distribution(self, groups, title_suffix=""):
        """Show side-by-side histograms of original and normalized group CATEs."""
        original_cates = [g['cate'] for g in groups]
        normalized_cates = [g['normalized_cate'] for g in groups]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        ax1.hist(original_cates, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
        ax1.set_xlabel('Original CATE ($ earnings effect)')
        ax1.set_ylabel('Frequency')
        ax1.set_title(f'Original CATE Distribution{title_suffix}')
        ax1.grid(True, alpha=0.3)

        ax2.hist(normalized_cates, bins=15, alpha=0.7, color='lightcoral', edgecolor='black')
        ax2.set_xlabel('Normalized CATE (τ)')
        ax2.set_ylabel('Frequency')
        ax2.set_title(f'Normalized CATE Distribution{title_suffix}')
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def estimate_tau(self, true_tau, accuracy):
        """Estimate a Bernoulli mean from simulated samples.

        The sample size is n = ceil(ln(2/δ) / (2·accuracy²)), the two-sided
        Hoeffding bound for the given accuracy at failure probability δ. The
        samples are drawn from NumPy's global RNG.

        Returns:
            (sample_mean, n)
        """
        sample_size = int(np.ceil(np.log(2/self.delta) / (2 * accuracy**2)))
        samples = np.random.binomial(1, true_tau, sample_size)
        return np.mean(samples), sample_size

    def run_single_trial(self, groups, epsilon_val, trial_seed):
        """Simulate one allocation trial for every budget K in 1..len(groups)-1.

        The global RNG is re-seeded with `random_seed + trial_seed`. τ is then
        estimated for every group at accuracy ρ = γ√ε and at accuracy ε. For
        each K, the true value of the top-K chosen by each estimate is compared
        with the true optimal top-K value.

        The K-th interval [τ̂_K, τ̂_K + 2ρ] is flagged heavy when it contains
        more than `heavy_multiplier · 2ρ · n_groups` estimates.

        Returns:
            (list of per-K result dicts, array of ρ-accuracy estimates)
        """
        np.random.seed(self.random_seed + trial_seed)

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        rho = self.gamma * np.sqrt(epsilon_val)

        # Separate loops keep the RNG draw order (all ρ draws, then all ε draws).
        tau_estimates_rho = np.array([self.estimate_tau(tau, rho)[0] for tau in tau_true])
        tau_estimates_eps = np.array([self.estimate_tau(tau, epsilon_val)[0] for tau in tau_true])

        true_order = np.argsort(tau_true)
        rho_order = np.argsort(tau_estimates_rho)
        eps_order = np.argsort(tau_estimates_eps)
        expected_a2 = 2 * rho * n_groups

        results = []
        for K in range(1, n_groups):
            optimal_value = np.sum(tau_true[true_order[-K:]])

            rho_indices = rho_order[-K:]
            rho_value = np.sum(tau_true[rho_indices])

            eps_value = np.sum(tau_true[eps_order[-K:]])

            rho_ratio = rho_value / optimal_value if optimal_value > 0 else 0
            eps_ratio = eps_value / optimal_value if optimal_value > 0 else 0
            rho_success = rho_ratio >= (1 - epsilon_val)
            eps_success = eps_ratio >= (1 - epsilon_val)

            # Smallest estimate inside the estimated top-K.
            tau_k_est = tau_estimates_rho[rho_indices[0]]
            units_in_a2 = np.sum((tau_estimates_rho >= tau_k_est) &
                                 (tau_estimates_rho <= tau_k_est + 2 * rho))
            is_heavy = units_in_a2 > self.heavy_multiplier * expected_a2

            results.append({
                'K': K,
                'optimal_value': optimal_value,
                'rho_value': rho_value,
                'eps_value': eps_value,
                'rho_ratio': rho_ratio,
                'eps_ratio': eps_ratio,
                'rho_success': rho_success,
                'eps_success': eps_success,
                'is_heavy': is_heavy,
                'tau_k_est': tau_k_est,
                'units_in_a2': units_in_a2
            })

        return results, tau_estimates_rho

    def find_recovery_units(self, K, tau_true, tau_estimates, epsilon_val, max_extra=10):
        """Return the fewest extra units that lift an estimated top-K to a (1 - ε) fraction of the optimum.

        Extra units are taken in decreasing order of estimate from outside the
        estimated top-K.

        Returns:
            The number of extra units (1..max_extra), or None if `max_extra`
            units are not enough or the optimal value is not positive.
        """
        order = np.argsort(tau_estimates)
        rho_indices = order[-K:]
        optimal_value = np.sum(tau_true[np.argsort(tau_true)[-K:]])
        if optimal_value <= 0:
            return None

        # Remaining candidates, best estimate first.
        remaining_indices = order[:-K][::-1]

        for extra in range(1, min(max_extra, len(remaining_indices)) + 1):
            expanded_indices = np.concatenate([rho_indices, remaining_indices[:extra]])
            if np.sum(tau_true[expanded_indices]) / optimal_value >= (1 - epsilon_val):
                return extra

        return None

    def find_closest_working_budget(self, failed_K, trial_results):
        """Measure distances from a failed budget to successful ones.

        Returns:
            A pair (distance to the nearest successful budget in either
            direction, distance to the nearest smaller successful budget).
            Either element is None when no such budget exists.
        """
        working_budgets = [r['K'] for r in trial_results if r['rho_success']]

        if not working_budgets:
            return None, None

        min_distance_any = min(abs(K - failed_K) for K in working_budgets)

        smaller_working = [K for K in working_budgets if K < failed_K]
        min_distance_smaller = failed_K - max(smaller_working) if smaller_working else None

        return min_distance_any, min_distance_smaller

    def analyze_method(self, groups, epsilon_val, n_trials=30):
        """Run `n_trials` allocation trials and collect per-trial failure statistics.

        For each trial the following are recorded:

        - Failed budgets.
        - Heavy-interval flags based on estimated τ and on true τ.
        - Recovery units needed.
        - Distances to the nearest working budget.

        Returns:
            List of per-trial summary dicts.
        """
        print(f"\nAnalyzing {len(groups)} groups with ε={epsilon_val}, γ={self.gamma}")

        n_groups = len(groups)
        tau_true = np.array([g['normalized_cate'] for g in groups])
        # Loop invariants: tau_true and rho do not change across trials.
        true_order = np.argsort(tau_true)
        rho = self.gamma * np.sqrt(epsilon_val)
        expected_a2_true = 2 * rho * n_groups

        trial_data = []

        for trial in range(n_trials):
            print(f"Trial {trial + 1}/{n_trials}...")

            trial_results, tau_estimates = self.run_single_trial(groups, epsilon_val, trial)

            failed_results = [r for r in trial_results if not r['rho_success']]
            failed_budgets = [r['K'] for r in failed_results]

            failed_heavy_estimated = []
            failed_heavy_true = []

            for failed_result in failed_results:
                K = failed_result['K']
                failed_heavy_estimated.append(failed_result['is_heavy'])

                # Same heavy test, but anchored at the true K-th largest τ.
                tau_k_true = tau_true[true_order[-K]]
                units_in_a2_true = np.sum((tau_true >= tau_k_true) & (tau_true <= tau_k_true + 2 * rho))
                failed_heavy_true.append(units_in_a2_true > self.heavy_multiplier * expected_a2_true)

            print(f"  Failed budgets: {failed_budgets}")

            if len(failed_budgets) > 0:
                estimated_clean = [bool(x) for x in failed_heavy_estimated]
                true_clean = [bool(x) for x in failed_heavy_true]
                print(f"  HEAVY INTERVALS - Estimated: {estimated_clean}")
                print(f"  HEAVY INTERVALS - True τ_K:   {true_clean}")

            total_heavy = sum(r['is_heavy'] for r in trial_results)
            failed_heavy = sum(r['is_heavy'] for r in failed_results)

            recovery_units = []
            distances_to_working_any = []
            distances_to_working_smaller = []

            for failed_result in failed_results:
                K = failed_result['K']

                recovery = self.find_recovery_units(K, tau_true, tau_estimates, epsilon_val)
                if recovery is not None:
                    recovery_units.append(recovery)

                distance_any, distance_smaller = self.find_closest_working_budget(K, trial_results)
                if distance_any is not None:
                    distances_to_working_any.append(distance_any)
                if distance_smaller is not None:
                    distances_to_working_smaller.append(distance_smaller)

            trial_data.append({
                'trial': trial,
                'failed_budgets': failed_budgets,
                'num_failures': len(failed_results),
                'total_heavy': total_heavy,
                'failed_heavy': failed_heavy,
                'failed_heavy_estimated': failed_heavy_estimated,
                'failed_heavy_true': failed_heavy_true,
                'recovery_units': recovery_units,
                'distances_to_working_any': distances_to_working_any,
                'distances_to_working_smaller': distances_to_working_smaller
            })

            print(f"  Failures: {len(failed_results)}, Total heavy: {total_heavy}, Failed heavy: {failed_heavy}")
            if recovery_units:
                print(f"  Recovery units: μ={np.mean(recovery_units):.1f}, med={np.median(recovery_units):.0f}, max={np.max(recovery_units)}")
            if distances_to_working_any:
                print(f"  Distance any: μ={np.mean(distances_to_working_any):.1f}, med={np.median(distances_to_working_any):.0f}, max={np.max(distances_to_working_any)}")
            if distances_to_working_smaller:
                print(f"  Distance smaller: μ={np.mean(distances_to_working_smaller):.1f}, med={np.median(distances_to_working_smaller):.0f}, max={np.max(distances_to_working_smaller)}")
            else:
                print(f"  Distance smaller: No smaller working budgets found")

        return trial_data

    def print_method_summary(self, method_name, trial_data, n_groups, epsilon_val):
        """Print aggregate statistics across trials and return the key figures.

        Failure rates are expressed relative to the n_groups - 1 budgets
        evaluated per trial. The denominator is clamped to at least 1 so that
        a single-group method does not divide by zero.

        Returns:
            Dict with avg_failures, failure_rate_pct, avg_recovery and n_groups.
        """
        print(f"\n{'='*100}")
        print(f"SUMMARY - {method_name} - ε={epsilon_val} - {n_groups} GROUPS")
        print("="*100)
        print(f"{'Fail μ':<7} {'Fail σ':<7} {'FailR% μ':<9} {'FailR% σ':<9} {'TotHvy':<8} {'FailHvy':<9} {'Rec μ':<7} {'Rec med':<8} {'Rec max':<8} {'DAny μ':<8} {'DAny σ':<10} {'DAny max':<10} {'DSmall μ':<10} {'DSmall σ':<12} {'DSmall max':<12}")
        print("-"*120)

        all_failures = [t['num_failures'] for t in trial_data]
        all_total_heavy = [t['total_heavy'] for t in trial_data]
        all_failed_heavy = [t['failed_heavy'] for t in trial_data]
        all_recovery = [u for t in trial_data for u in t['recovery_units']]
        all_distances_any = [d for t in trial_data for d in t['distances_to_working_any']]
        all_distances_smaller = [d for t in trial_data for d in t['distances_to_working_smaller']]

        n_budgets = max(n_groups - 1, 1)
        avg_failures = np.mean(all_failures)
        std_failures = np.std(all_failures)
        avg_failure_rate = avg_failures / n_budgets * 100
        std_failure_rate = std_failures / n_budgets * 100
        avg_total_heavy = np.mean(all_total_heavy)
        avg_failed_heavy = np.mean(all_failed_heavy)

        if all_recovery:
            recovery_mean = np.mean(all_recovery)
            recovery_med = np.median(all_recovery)
            recovery_max = np.max(all_recovery)
        else:
            recovery_mean = recovery_med = recovery_max = np.nan

        if all_distances_any:
            distance_any_mean = np.mean(all_distances_any)
            distance_any_std = np.std(all_distances_any)
            distance_any_max = np.max(all_distances_any)
        else:
            distance_any_mean = distance_any_std = distance_any_max = np.nan

        if all_distances_smaller:
            distance_smaller_mean = np.mean(all_distances_smaller)
            distance_smaller_std = np.std(all_distances_smaller)
            distance_smaller_max = np.max(all_distances_smaller)
        else:
            distance_smaller_mean = distance_smaller_std = distance_smaller_max = np.nan

        print(f"{avg_failures:<7.1f} {std_failures:<7.1f} {avg_failure_rate:<9.1f} {std_failure_rate:<9.1f} {avg_total_heavy:<8.1f} {avg_failed_heavy:<9.1f} "
              f"{recovery_mean:<7.1f} {recovery_med:<8.0f} {recovery_max:<8.0f} "
              f"{distance_any_mean:<8.1f} {distance_any_std:<10.1f} {distance_any_max:<10.0f} "
              f"{distance_smaller_mean:<10.1f} {distance_smaller_std:<12.1f} {distance_smaller_max:<12.0f}")

        return {
            'avg_failures': avg_failures,
            'failure_rate_pct': avg_failure_rate,
            'avg_recovery': recovery_mean,
            'n_groups': n_groups
        }


def run_comprehensive_nsw_analysis(df_nsw, epsilon_values=None, n_trials=30, log_file=None,
                                   *, gamma=0.5, heavy_multiplier=1.6):
    """Run every NSW grouping method across a grid of epsilon values and summarise.

    For each grouping method and each epsilon, a fresh ``NSWCATEAllocator`` is
    built, the data is processed, groups are created and CATE-normalised, and
    ``allocator.analyze_method`` is run for ``n_trials`` trials. All console
    output is mirrored to ``log_file`` via ``TeeOutput``; ``sys.stdout`` is
    restored and the log closed even if an error escapes.

    Args:
        df_nsw: Raw NSW DataFrame, passed to ``allocator.process_nsw_data``.
        epsilon_values: Epsilon grid. Defaults to
            ``[0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]``.
        n_trials: Number of trials passed to ``allocator.analyze_method``.
        log_file: Path of the log file. Defaults to a timestamped name of the
            form ``nsw_comprehensive_analysis_gamma<γ without dot>_<timestamp>.txt``.
        gamma: Fixed γ used for every allocator (ρ = γ·√ε). Defaults to 0.5.
        heavy_multiplier: Heavy-interval threshold multiplier for every
            allocator. Defaults to 1.6.

    Returns:
        ``(all_results, summary_data)`` where ``all_results`` maps each method
        name to a list of per-epsilon dicts (keys: ``method``, ``epsilon``,
        ``sqrt_epsilon``, ``gamma``, ``rho``, ``groups``, ``trial_data``,
        ``stats``), and ``summary_data`` is a flat list of row dicts (keys:
        ``method``, ``epsilon``, ``sqrt_eps``, ``gamma``, ``rho``,
        ``avg_failures``, ``failure_rate_pct``, ``avg_recovery``, ``n_groups``).
    """
    if epsilon_values is None:
        epsilon_values = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]

    # Set up logging; the gamma tag reproduces the historical "gamma05" name for γ=0.5.
    if log_file is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        gamma_tag = str(gamma).replace('.', '')
        log_file = f"nsw_comprehensive_analysis_gamma{gamma_tag}_{timestamp}.txt"

    # Redirect output to both console and file
    original_stdout = sys.stdout
    tee = TeeOutput(log_file)
    sys.stdout = tee

    try:
        print(f"COMPREHENSIVE NSW ANALYSIS - ALL METHODS, FIXED γ={gamma}, HEAVY THRESHOLD={heavy_multiplier}x")
        print(f"Log file: {log_file}")
        print(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*100)

        # Define all NSW-specific grouping methods
        methods = [
            ('Demographics', lambda allocator, df: allocator.create_demographics_groups(df, min_size=6)),
            ('Race/Ethnicity', lambda allocator, df: allocator.create_race_ethnicity_groups(df, min_size=6)),
            ('Age Groups', lambda allocator, df: allocator.create_age_groups(df, n_groups=30, min_size=6)),
            ('Education', lambda allocator, df: allocator.create_education_groups(df, min_size=6)),
            ('Baseline Earnings', lambda allocator, df: allocator.create_baseline_earnings_groups(df, n_groups=30, min_size=6)),
            ('Employment Status', lambda allocator, df: allocator.create_employment_status_groups(df, min_size=6)),
            ('Causal Forest 30', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=30, min_size=6)),
            ('Causal Forest 50', lambda allocator, df: allocator.create_causal_forest_groups(df, n_groups=50, min_size=6)),
            ('Propensity Score', lambda allocator, df: allocator.create_propensity_groups(df, n_groups=50, min_size=6))
        ]

        all_results = {}

        for method_name, method_func in methods:
            print(f"\n{'='*120}")
            print(f"ANALYZING NSW METHOD: {method_name}")
            print("="*120)

            method_results = []

            for eps in epsilon_values:
                print(f"\n{'='*100}")
                print(f"METHOD: {method_name} | EPSILON = {eps}")
                print("="*100)

                # Fresh allocator per epsilon (its constructor reseeds numpy's RNG).
                allocator = NSWCATEAllocator(epsilon=eps, gamma=gamma, heavy_multiplier=heavy_multiplier)
                df_processed = allocator.process_nsw_data(df_nsw)

                try:
                    groups = method_func(allocator, df_processed)

                    if len(groups) < 3:
                        print(f"Too few groups ({len(groups)}) for {method_name} with ε = {eps} - skipping")
                        continue

                    groups = allocator.normalize_cates(groups)
                    allocator.plot_cate_distribution(groups, f" ({method_name}, ε={eps})")
                    trial_data = allocator.analyze_method(groups, eps, n_trials)
                    stats = allocator.print_method_summary(method_name, trial_data, len(groups), eps)
                except Exception as e:
                    # Best-effort sweep: report (with the exception type) and move on.
                    print(f"Error with {method_name} at ε = {eps}: {type(e).__name__}: {e}")
                    continue

                method_results.append({
                    'method': method_name,
                    'epsilon': eps,
                    'sqrt_epsilon': np.sqrt(eps),
                    'gamma': gamma,
                    'rho': gamma * np.sqrt(eps),
                    'groups': groups,
                    'trial_data': trial_data,
                    'stats': stats
                })

            all_results[method_name] = method_results

            # Method-specific summary table after all epsilons for this method
            if method_results:
                print(f"\n{'='*120}")
                print(f"METHOD SUMMARY - {method_name} - ALL EPSILON VALUES")
                print("="*120)
                _print_summary_table(_summary_row(method_name, r) for r in method_results)
                print("="*120)

        # Comprehensive summary across all methods and epsilon values
        print(f"\n{'='*200}")
        print("COMPREHENSIVE SUMMARY - ALL NSW METHODS AND EPSILON VALUES")
        print("="*200)

        summary_data = []
        rows_by_method = {}  # method name -> its summary rows (only methods with results)

        for method_name, method_results in all_results.items():
            if not method_results:
                continue

            print(f"\n{'-'*100}")
            print(f"NSW METHOD: {method_name}")
            print("-"*100)

            method_rows = [_summary_row(method_name, r) for r in method_results]
            rows_by_method[method_name] = method_rows
            summary_data.extend(method_rows)
            _print_summary_table(method_rows)

        # Overall summary table
        print(f"\n{'='*200}")
        print("OVERALL SUMMARY TABLE - ALL NSW METHODS COMBINED")
        print("="*200)
        _print_summary_table(summary_data, include_method=True)

        # NSW-specific analysis insights
        print(f"\n{'='*100}")
        print("KEY INSIGHTS FOR NSW DATASET")
        print("="*100)

        # Average failure rate across all epsilon values, per method
        method_performance = {
            name: np.mean([row['failure_rate_pct'] for row in rows])
            for name, rows in rows_by_method.items()
        }

        if method_performance:
            best_method = min(method_performance, key=method_performance.get)
            worst_method = max(method_performance, key=method_performance.get)

            print(f"BEST PERFORMING NSW METHOD: {best_method}")
            print(f"  Average failure rate: {method_performance[best_method]:.1f}%")

            print(f"\nWORST PERFORMING NSW METHOD: {worst_method}")
            print(f"  Average failure rate: {method_performance[worst_method]:.1f}%")

            print("\nNSW METHOD RANKING (by average failure rate):")
            sorted_methods = sorted(method_performance.items(), key=lambda x: x[1])
            for i, (method, rate) in enumerate(sorted_methods, 1):
                print(f"  {i}. {method}: {rate:.1f}%")

        # Effect of epsilon, averaged over methods
        print("\nEFFECT OF EPSILON ON NSW:")
        epsilon_performance = {}
        for eps in epsilon_values:
            eps_rates = [row['failure_rate_pct'] for row in summary_data if row['epsilon'] == eps]
            if eps_rates:
                epsilon_performance[eps] = np.mean(eps_rates)

        if epsilon_performance:
            rho_label = f"ρ = {gamma}√ε"
            print(f"{'Epsilon':<10} {'Avg Failure Rate':<15} {rho_label:<12}")
            print("-" * 40)
            for eps in sorted(epsilon_performance):
                rho = gamma * np.sqrt(eps)
                print(f"{eps:<10} {epsilon_performance[eps]:<15.1f} {rho:<12.6f}")

        return all_results, summary_data

    finally:
        # Restore original stdout and close log file
        sys.stdout = original_stdout
        tee.close()


def _summary_row(method_name, eps_result):
    """Flatten one per-epsilon result dict into a ``summary_data`` row."""
    stats = eps_result['stats']
    return {
        'method': method_name,
        'epsilon': eps_result['epsilon'],
        'sqrt_eps': eps_result['sqrt_epsilon'],
        'gamma': eps_result['gamma'],
        'rho': eps_result['rho'],
        'avg_failures': stats['avg_failures'],
        'failure_rate_pct': stats['failure_rate_pct'],
        'avg_recovery': stats['avg_recovery'],
        'n_groups': stats['n_groups']
    }


def _format_summary_row(row, include_method=False):
    """Render one summary row as a fixed-width line (optionally prefixed by the method)."""
    line = (f"{row['epsilon']:<8} {row['sqrt_eps']:<10.6f} {row['gamma']:<6} {row['rho']:<10.6f} "
            f"{row['n_groups']:<8} {row['avg_failures']:<8.1f} {row['failure_rate_pct']:<8.1f} "
            f"{row['avg_recovery']:<8.1f}")
    if include_method:
        line = f"{row['method']:<18} " + line
    return line


def _print_summary_table(rows, include_method=False):
    """Print the ε/√ε/γ/ρ/Groups/Fail/Rec table header, separator and ``rows``."""
    header = f"{'ε':<8} {'√ε':<10} {'γ':<6} {'ρ':<10} {'Groups':<8} {'Fail μ':<8} {'FailR%':<8} {'Rec μ':<8}"
    if include_method:
        header = f"{'Method':<18} " + header
    print(header)
    # The combined table is wider because of the method column.
    print("-" * (100 if include_method else 80))
    for row in rows:
        print(_format_summary_row(row, include_method))


# Script entry point: run the full NSW sweep on nsw.dta when executed directly.
if __name__ == "__main__":
    # Epsilon grid mirrors the TUP analysis so both runs can be compared side by side.
    epsilon_grid = [0.001, 0.005, 0.01, 0.02, 0.05, 0.1, 0.15, 0.2]
    df_nsw = pd.read_stata('nsw.dta')

    run_options = {
        'epsilon_values': epsilon_grid,
        'n_trials': 30,
        'log_file': "nsw_comprehensive_analysis_gamma05_heavy16.txt",
    }
    results, summary = run_comprehensive_nsw_analysis(df_nsw, **run_options)